#----------------------------------------------------------------------
#  GFDM method test - 2d Navier-Cauchy equation
#  Square plate with circular hole
#  Author: Andrea Pavan
#  Date: 22/12/2022
#  License: GPLv3-or-later
#----------------------------------------------------------------------
using ElasticArrays;
using LinearAlgebra;
using SparseArrays;
using PyPlot;
include("utils.jl");


#problem definition
#quarter model of a square plate with a circular hole of radius a centred at the origin
l1, l2 = 5.0, 5.0;      #domain x size, domain y size
a = 1.0;                #hole radius
σ0 = 1.0;               #traction load (applied on the right edge x=l1)
E, ν = 1000, 0.3;       #Young modulus, Poisson ratio

#discretization parameters
meshSize = 0.2;                     #reference node spacing
surfaceMeshSize = meshSize;         #surface (boundary) node spacing
minNeighbors = 10;                  #minimum number of neighbors required for each node
minSearchRadius = 0.5*meshSize;     #initial neighbor search radius

"""
    exact_σxx(x, y, σload=σ0, R=a)

Return the analytical (Kirsch) σxx stress at `(x,y)` for an infinite plate with a
circular hole of radius `R` centred at the origin, loaded by a uniform traction `σload`
along x. `σload` and `R` default to the problem's `σ0` and `a`, so `exact_σxx(x,y)`
behaves as before. The formula is meaningful only outside the hole (`x^2+y^2 ≥ R^2`).
"""
function exact_σxx(x, y, σload=σ0, R=a)
    r2 = x^2 + y^2;         #squared distance from the hole centre
    θ = atan(y, x);
    ρ2 = R^2/r2;            #(R/r)^2
    return σload*(1 - ρ2*(1.5*cos(2*θ) + cos(4*θ)) + 1.5*ρ2^2*cos(4*θ));
end
#Lamé coefficients (λ in its 3D form, i.e. the plane strain model)
μ, λ = 0.5*E/(1+ν), E*ν/((1+ν)*(1-2*ν));


#read pointcloud from a SU2 file
time1 = time();
pointcloud = ElasticArray{Float64}(undef,2,0);      #2xN matrix with the [X;Y] coordinates of each node
boundaryNodes = Vector{Int}(undef,0);       #indices of the boundary nodes
internalNodes = Vector{Int}(undef,0);       #indices of the internal nodes
normals = ElasticArray{Float64}(undef,2,0);     #2xN matrix with the [nx;ny] normal components of the boundary nodes

pointcloud = parseSU2mesh("12b_direct_2d_hole_plate_su2_mesh_772.su2");
#pointcloud = parseSU2mesh("12b_direct_2d_hole_plate_su2_mesh_2951.su2");

#drop the geometric corner points (exact coordinate match; a point is removed only if present),
#presumably because the boundary normal is ambiguous there
#`global` is needed so the reassignment updates the script variable when run as a file
for (xc, yc) in ((a,0), (l1,0), (l1,l2), (0,l2), (0,a), (0,0))
    global pointcloud;
    iscorner = (pointcloud[1,:] .== xc) .& (pointcloud[2,:] .== yc);
    pointcloud = pointcloud[:, .!iscorner];
end
N = size(pointcloud,2);

#classify every node as boundary or internal and store the outward unit normal of each
#boundary node. normals is a 2xN matrix indexed by NODE index (internal nodes keep a zero
#column), because the boundary-condition rows read it as normals[:,i] with i a node index.
normals = zeros(Float64, 2, N);
for i=1:N
    px = pointcloud[1,i];
    py = pointcloud[2,i];
    if py==0
        #bottom edge
        push!(boundaryNodes, i);
        normals[:,i] = [0,-1];
    elseif px==l1
        #right edge
        push!(boundaryNodes, i);
        normals[:,i] = [1,0];
    elseif py==l2
        #top edge
        push!(boundaryNodes, i);
        normals[:,i] = [0,1];
    elseif px==0
        #left edge
        push!(boundaryNodes, i);
        normals[:,i] = [-1,0];
    elseif px^2+py^2 <= (a+1e-3)^2
        #hole boundary (small tolerance on the radius): the outward normal points towards
        #the hole centre; normalise by the actual radius so it is a unit vector
        push!(boundaryNodes, i);
        normals[:,i] = -pointcloud[:,i]/hypot(px,py);
    else
        push!(internalNodes, i);
    end
end

println("Generated pointcloud in ", round(time()-time1,digits=2), " s");
println("Pointcloud properties:");
println("  Boundary nodes: ",length(boundaryNodes));
println("  Internal nodes: ",length(internalNodes));
println("  Memory: ",memoryUsage(pointcloud,boundaryNodes));

#pointcloud plot: boundary nodes in red, internal nodes in black
figure();
for (idx, marker) in ((boundaryNodes, "r."), (internalNodes, "k."))
    plot(pointcloud[1,idx], pointcloud[2,idx], marker);
end
title("Pointcloud plot");
axis("equal");
display(gcf());


#boundary conditions
#uD/vD: prescribed displacements, NaN meaning "not prescribed" (the traction uN/vN is used instead)
#uN/vN: prescribed x/y tractions, zero unless set below
N = size(pointcloud,2);     #number of nodes
uD = fill(NaN, N);
vD = fill(NaN, N);
uN = zeros(N);
vN = zeros(N);
for i in boundaryNodes
    px, py = pointcloud[1,i], pointcloud[2,i];
    py==0 && (vD[i] = 0.0);      #bottom: v = 0
    px==l1 && (uN[i] = σ0);      #right: x traction σ0
    px==0 && (uD[i] = 0.0);      #left: u = 0
end


#neighbor search
#box search (|Δx| and |Δy| both below the radius) over every other node; the radius grows
#by minSearchRadius/2 until node i has at least minNeighbors neighbors (O(N) per pass)
time2 = time();
N = size(pointcloud,2);     #number of nodes
neighbors = Vector{Vector{Int}}(undef,N);       #neighbors[i]: indices of the neighbors of node i
Nneighbors = zeros(Int,N);      #Nneighbors[i]: number of neighbors of node i
for i=1:N
    radius = minSearchRadius;
    while Nneighbors[i] < minNeighbors
        neighbors[i] = [j for j=1:N if j!=i && all(abs.(pointcloud[:,j]-pointcloud[:,i]) .< radius)];
        Nneighbors[i] = length(neighbors[i]);
        radius += minSearchRadius/2;
    end
end
println("Found neighbors in ", round(time()-time2,digits=2), " s");
println("Connectivity properties:");
println("  Max neighbors: ",maximum(Nneighbors)," (at index ",argmax(Nneighbors),")");
println("  Avg neighbors: ",round(sum(Nneighbors)/length(Nneighbors),digits=2));
println("  Min neighbors: ",minimum(Nneighbors)," (at index ",argmin(Nneighbors),")");


#neighbors relative positions, squared distances and least-squares weights
time3 = time();
P = Vector{Array{Float64}}(undef,N);        #P[i]: 2xNneighbors[i] positions of the neighbors relative to node i
r2 = Vector{Vector{Float64}}(undef,N);      #r2[i]: squared distances of the neighbors of node i
w2 = Vector{Vector{Float64}}(undef,N);      #w2[i]: squared gaussian weights of the neighbors of node i
for i=1:N
    P[i] = pointcloud[:,neighbors[i]] .- pointcloud[:,i];
    r2[i] = [dot(p, p) for p in eachcol(P[i])];
    dmax = maximum(r2[i]);
    w2[i] = [exp(-6*d/dmax)^2 for d in r2[i]];
end
w2pde = 2.0;        #least squares weight for the pde
w2bc = 2.0;         #least squares weight for the boundary condition


#least square matrix inversion
#For node i the 12 unknowns are the Taylor coefficients of u and v at the node,
#  c = [u, ∂u/∂x, ∂u/∂y, ½∂²u/∂x², ½∂²u/∂y², ∂²u/∂x∂y,  (same six for v)],
#fitted in the weighted least-squares sense to the neighbor values (rows 1..2n), to the
#Navier-Cauchy equations divided by μ (rows 2n+1, 2n+2) and, for boundary nodes, to the
#two boundary conditions (rows 2n+3, 2n+4). C[i] = (VᵀWV)⁻¹VᵀW maps the right-hand sides
#of those rows to the coefficients.

"""
    navier_lsq_matrix(xj, yj, nbc, λ, μ)

Return the `(2n+2+nbc)×12` least-squares matrix for the neighbor offsets `xj`, `yj`
(`n = length(xj)`): quadratic Taylor rows for u (rows 1..n) and v (rows n+1..2n),
then the two Navier-Cauchy rows divided by μ. The last `nbc` rows are left as zeros
for the caller to fill with boundary-condition rows.
"""
function navier_lsq_matrix(xj, yj, nbc, λ, μ)
    n = length(xj);
    V = zeros(Float64, 2+2*n+nbc, 12);
    for j=1:n
        taylor = [1, xj[j], yj[j], xj[j]^2, yj[j]^2, xj[j]*yj[j]];
        V[j,1:6] = taylor;
        V[j+n,7:12] = taylor;
    end
    V[1+2*n,:] = [0, 0, 0, 2*(2+λ/μ), 2, 0, 0, 0, 0, 0, 0, 1+λ/μ];      #x equilibrium
    V[2+2*n,:] = [0, 0, 0, 0, 0, 1+λ/μ, 0, 0, 0, 2, 2*(2+λ/μ), 0];      #y equilibrium
    return V;
end

A = Vector{Matrix}(undef,N);        #least-squares matrices VᵀWV
C = Vector{Matrix}(undef,N);        #derivatives coefficients matrices
condC = fill(NaN,N);        #condition number of C[i] (filled below for every node)
for i in internalNodes
    V = navier_lsq_matrix(P[i][1,:], P[i][2,:], 0, λ, μ);
    W = Diagonal(vcat(w2[i],w2[i],w2pde,w2pde));
    A[i] = transpose(V)*W*V;
    C[i] = qr(A[i]) \ (transpose(V)*W);     #QR solve instead of forming inv(R) explicitly
    condC[i] = cond(C[i]);
end
for i in boundaryNodes
    n = Nneighbors[i];
    nx, ny = normals[1,i], normals[2,i];
    V = navier_lsq_matrix(P[i][1,:], P[i][2,:], 2, λ, μ);
    #x-direction condition: prescribed u, or x traction divided by μ
    if !isnan(uD[i])
        V[3+2*n,:] = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
    else
        V[3+2*n,:] = [0, nx*(2+λ/μ), ny, 0, 0, 0, 0, ny, nx*(λ/μ), 0, 0, 0];
    end
    #y-direction condition: prescribed v, or y traction divided by μ
    if !isnan(vD[i])
        V[4+2*n,:] = [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0];
    else
        V[4+2*n,:] = [0, ny*(λ/μ), nx, 0, 0, 0, 0, nx, ny*(2+λ/μ), 0, 0, 0];
    end
    W = Diagonal(vcat(w2[i],w2[i],w2pde,w2pde,w2bc,w2bc));
    A[i] = transpose(V)*W*V;
    C[i] = qr(A[i]) \ (transpose(V)*W);
    condC[i] = cond(C[i]);
end
println("Inverted least-squares matrices in ", round(time()-time3,digits=2), " s");

#condition number distribution plot
#only finite entries are histogrammed: NaN/Inf values would break or distort the automatic bin range
figure();
plt.hist(filter(isfinite, condC), 10);
title("Condition number distribution");
xlabel("cond(C)");
ylabel("Absolute frequency");
display(gcf());


#matrix assembly
#row i:   u_i - Σ_j (C[i][1,j] u_j + C[i][1,j+n] v_j) = rhs
#row N+i: v_i - Σ_j (C[i][7,j] u_j + C[i][7,j+n] v_j) = rhs
#(coefficient row 1 is u at the node, row 7 is v at the node; unknowns are stacked [u; v])
time4 = time();
rows = Int[];
cols = Int[];
vals = Float64[];
for i=1:N
    nb = neighbors[i];
    n = Nneighbors[i];
    for (r, k) in ((i, 1), (N+i, 7))
        #diagonal entry
        push!(rows, r);
        push!(cols, r);
        push!(vals, 1);
        for j=1:n
            push!(rows, r);
            push!(cols, nb[j]);
            push!(vals, -C[i][k,j]);
            push!(rows, r);
            push!(cols, N+nb[j]);
            push!(vals, -C[i][k,j+n]);
        end
    end
end
M = sparse(rows,cols,vals,2*N,2*N);
println("Completed matrix assembly in ", round(time()-time4,digits=2), " s");


#linear system solution
#rhs: zero for internal nodes; for boundary nodes, the boundary-condition columns of C[i]
#(end-1: x condition, end: y condition) times the prescribed value, where a traction is
#divided by μ as in the boundary-condition rows
time5 = time();
b = zeros(2*N);       #rhs vector
for i in boundaryNodes
    for (r, k) in ((i, 1), (N+i, 7))
        b[r] += isnan(uD[i]) ? C[i][k,end-1]*uN[i]/μ : C[i][k,end-1]*uD[i];
        b[r] += isnan(vD[i]) ? C[i][k,end]*vN[i]/μ : C[i][k,end]*vD[i];
    end
end
sol = M\b;
println("Linear system solved in ", round(time()-time5,digits=2), " s");

#displacement field: the solution vector stacks [u; v]
u, v = sol[1:N], sol[N+1:end];
#=figure();
scatter(pointcloud[1,:],pointcloud[2,:],c=u,cmap="Oranges");
colorbar();
title("Hole plate - x displacement");
axis("equal");
display(gcf());=#


#σxx stress
#∂u/∂x and ∂v/∂y are coefficient rows 2 and 9 of C[i], applied to the neighbor displacements
#plus, for boundary nodes, to the prescribed boundary values (tractions divided by μ)
dudx = zeros(N);
dvdy = zeros(N);
for i=1:N
    n = Nneighbors[i];
    for (j, node) in enumerate(neighbors[i])
        dudx[i] += C[i][2,j]*u[node] + C[i][2,j+n]*v[node];
        dvdy[i] += C[i][9,j]*u[node] + C[i][9,j+n]*v[node];
    end
end
for i in boundaryNodes
    for (k, d) in ((2, dudx), (9, dvdy))
        d[i] += isnan(uD[i]) ? C[i][k,end-1]*uN[i]/μ : C[i][k,end-1]*uD[i];
        d[i] += isnan(vD[i]) ? C[i][k,end]*vN[i]/μ : C[i][k,end]*vD[i];
    end
end
#Hooke's law with the Lamé coefficients: σxx = (2μ+λ) ∂u/∂x + λ ∂v/∂y
σxx = dudx*(2μ+λ) + dvdy*λ;

#stress plot: σxx at every node, coloured with the "jet" colormap
figure();
scatter(pointcloud[1,:], pointcloud[2,:]; c=σxx, cmap="jet");
colorbar();
title("Hole plate - σxx stress");
axis("equal");
display(gcf());


#validation plot - σxx(0,y) stress
idxPlot = findall(pointcloud[1,:].==0);     #nodes on the left edge x=0
y_exactsol = collect(a:0.01:l2);
exactsol = exact_σxx.(0, y_exactsol);      #analytical σxx along x=0
figure();
plot(pointcloud[2,idxPlot],σxx[idxPlot],"r.",label="GFDM");
plot(y_exactsol,exactsol,"k-",linewidth=1.0,label="Analytical");
title("σxx stress @x=0");
legend(loc="upper right");
xlabel("y coordinate");
ylabel("σxx stress");
axis("equal");
display(gcf());

#validation plot - σxx(x,0) stress
idxPlot = findall(pointcloud[2,:].==0);     #nodes on the bottom edge y=0
x_exactsol = collect(a:0.01:l1);
exactsol = exact_σxx.(x_exactsol, 0);      #analytical σxx along y=0

#pointwise error against the analytical (infinite-plate) solution at every node, and its RMSE
error_σxx = σxx-exact_σxx.(pointcloud[1,:],pointcloud[2,:]);
σxx_rmse = sqrt(sum((error_σxx).^2)/N);
println("σxx RMSE: ", round(σxx_rmse,sigdigits=4));

figure();
plot(pointcloud[1,idxPlot],σxx[idxPlot],"r.",label="GFDM");
plot(x_exactsol,exactsol,"k-",linewidth=1.0,label="Analytical");
title("σxx stress @y=0");
legend(loc="upper right");
xlabel("x coordinate");
ylabel("σxx stress");
axis("equal");
display(gcf());

#exact solution plot
#=figure();
scatter(pointcloud[1,:],pointcloud[2,:],c=exact_σxx.(pointcloud[1,:],pointcloud[2,:]),cmap="jet");
colorbar();
title("Exact solution");
axis("equal");
display(gcf());=#
